{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Creating a custom loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loss functions are crucial because they define the objective to optimize for during training. Classy Vision can work directly with loss functions defined in [PyTorch](https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html) without the need for any wrapper classes, but during research it's common to create custom losses with hyperparameters. Using `ClassyLoss` you can expose these hyperparameters via a configuration file.\n",
    "\n",
    "This tutorial will demonstrate: \n",
    "1. How to create a custom loss within Classy Vision; \n",
    "2. How to integrate your loss with Classy Vision's configuration system;\n",
    "3. How to use a ClassyLoss independently, without other Classy Vision abstractions.\n",
    "\n",
    "## 1. Defining a loss\n",
    "\n",
    "Creating a new loss in Classy Vision is as simple as adding a new loss within PyTorch. The loss has to derive from `ClassyLoss` (which inherits from [`torch.nn.Module`](https://pytorch.org/docs/stable/nn.html#module)), and implement a `forward` method.\n",
    "\n",
    "> **Note**: The forward method should take the right arguments depending on the task the loss will be used for. For instance, a `ClassificationTask` passes the `output` and `target` to `forward`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from classy_vision.losses import ClassyLoss\n",
    "\n",
    "class MyLoss(ClassyLoss):\n",
    "    def __init__(self, alpha):\n",
    "        super().__init__()\n",
    "        self.alpha = alpha\n",
    "    \n",
    "    def forward(self, output, target):\n",
    "        return (output - target).pow(2) * self.alpha"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can start using this loss for training. Take a look at our [Getting started](https://classyvision.ai/tutorials/getting_started) tutorial for more details on how to train a model from a Jupyter notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from classy_vision.tasks import ClassificationTask\n",
    "\n",
    "my_loss = MyLoss(alpha=5)\n",
    "my_task = ClassificationTask().set_loss(my_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Integrate it with the configuration system\n",
    "\n",
    "To be able to use the registration mechanism to be able to pick up the loss from a configuration, we need to do two additional things -\n",
    "- Implement a `from_config` method\n",
    "- Add the `register_loss` decorator to `MyLoss`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from classy_vision.losses import ClassyLoss, register_loss\n",
    "\n",
    "@register_loss(\"my_loss\")\n",
    "class MyLoss(ClassyLoss):\n",
    "    def __init__(self, alpha):\n",
    "        super().__init__()\n",
    "        self.alpha = alpha\n",
    "\n",
    "    @classmethod\n",
    "    def from_config(cls, config):\n",
    "        if \"alpha\" not in config:\n",
    "            raise ValueError('Need \"alpha\" in config for MyLoss')\n",
    "        return cls(alpha=config[\"alpha\"])\n",
    "        \n",
    "    def forward(self, output, target):\n",
    "        return (output - target).pow(2).sum() * self.alpha"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can start using this loss in our configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from classy_vision.losses import build_loss\n",
    "import torch\n",
    "\n",
    "loss_config = {\n",
    "    \"name\": \"my_loss\",\n",
    "    \"alpha\": 5\n",
    "}\n",
    "my_loss = build_loss(loss_config)\n",
    "assert isinstance(my_loss, MyLoss)\n",
    "\n",
    "# ClassyLoss inherits from torch.nn.Module, so it works as expected\n",
    "with torch.no_grad():\n",
    "    y_hat, target = torch.rand((1, 10)), torch.rand((1, 10))\n",
    "    print(my_loss(y_hat, target))"
   ]
  },
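  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build some intuition for what the decorator and `build_loss` do, here is a minimal, self-contained sketch of a name-to-class registry. It is not Classy Vision's actual implementation: `_REGISTRY`, `register`, `build`, and `Scaled` are hypothetical names used only for illustration, and the real library may differ in its details.\n",
    "\n",
    "```python\n",
    "# Hypothetical sketch of a name-to-class registry; not Classy Vision's code.\n",
    "_REGISTRY = {}\n",
    "\n",
    "\n",
    "def register(name):\n",
    "    # Return a class decorator that records the class under `name`.\n",
    "    def decorator(cls):\n",
    "        if name in _REGISTRY:\n",
    "            raise ValueError(f'{name} is already registered')\n",
    "        _REGISTRY[name] = cls\n",
    "        return cls\n",
    "    return decorator\n",
    "\n",
    "\n",
    "def build(config):\n",
    "    # Look up the class by the config's name, then let it parse the config.\n",
    "    cls = _REGISTRY[config['name']]\n",
    "    return cls.from_config(config)\n",
    "\n",
    "\n",
    "@register('scaled')\n",
    "class Scaled:\n",
    "    def __init__(self, alpha):\n",
    "        self.alpha = alpha\n",
    "\n",
    "    @classmethod\n",
    "    def from_config(cls, config):\n",
    "        return cls(alpha=config['alpha'])\n",
    "\n",
    "\n",
    "obj = build({'name': 'scaled', 'alpha': 5})\n",
    "```\n",
    "\n",
    "The decorator stores the class under the given name, and the builder uses the `name` entry of the config to find the class before handing the whole config to `from_config`."
   ]
  },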
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that your loss is integrated with the configuration system, you can train it using `classy_train.py` as described in the [Getting started](https://classyvision.ai/tutorials/getting_started) tutorial, no further changes are needed! Just make sure the code defining your model is in the `losses` folder of your classy project."
   ]
  },
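  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch, assuming the task configuration nests the loss settings under a top-level `loss` key, the relevant part of a config file might look like the fragment below. The inner dictionary mirrors the one passed to `build_loss` earlier; all other sections of the config are omitted, and the exact layout should be checked against the Getting started tutorial.\n",
    "\n",
    "```json\n",
    "{\n",
    "    \"loss\": {\n",
    "        \"name\": \"my_loss\",\n",
    "        \"alpha\": 5\n",
    "    }\n",
    "}\n",
    "```"
   ]
  },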
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Conclusion\n",
    "\n",
    "In this tutorial, we learned how to make your loss compatible with Classy Vision and how to integrate it with the configuration system. Refer to our documentation to learn more about [ClassyLoss](https://classyvision.ai/api/losses.html)."
   ]
  }
 ],
 "metadata": {
  "bento_stylesheets": {
   "bento/extensions/flow/main.css": true,
   "bento/extensions/kernel_selector/main.css": true,
   "bento/extensions/kernel_ui/main.css": true,
   "bento/extensions/new_kernel/main.css": true,
   "bento/extensions/system_usage/main.css": true,
   "bento/extensions/theme/main.css": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
